package com.yeskery.nut.extend.mybatis;

import com.yeskery.nut.bean.ApplicationContext;
import com.yeskery.nut.bean.NoSuchBeanException;
import com.yeskery.nut.core.NutException;
import com.yeskery.nut.plugin.ApplicationContextPluginPostProcessor;
import com.yeskery.nut.plugin.NutApplicationSupportBasePlugin;
import com.yeskery.nut.plugin.PluginBeanMetadata;
import com.yeskery.nut.transaction.TransactionManager;
import com.yeskery.nut.transaction.TransactionRegistry;
import com.yeskery.nut.util.ReflectUtils;
import org.apache.ibatis.logging.jdk14.Jdk14LoggingImpl;
import org.apache.ibatis.mapping.Environment;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.SqlSessionFactory;

import javax.sql.DataSource;
import java.lang.reflect.Proxy;
import java.util.*;

/**
 * Base MyBatis plugin.
 *
 * <p>Registers the configured {@link SqlSessionFactory} and a proxy for every mapper
 * interface known to the MyBatis configuration as plugin beans, and, once the application
 * context is available, installs an {@link Environment} that uses a
 * {@code NutTransactionFactory} together with the resolved {@link DataSource}.
 *
 * @author sprout
 * @version 1.0
 * 2022-08-25 21:29
 */
public class BaseMyBatisPlugin extends NutApplicationSupportBasePlugin implements ApplicationContextPluginPostProcessor {

    /** SQL session factory; may be {@code null} until set when the no-arg constructor is used. */
    private SqlSessionFactory sqlSessionFactory;

    /** Data source supplied by the caller, or looked up from the application context in {@link #process}. */
    private DataSource dataSource;

    /**
     * Creates the plugin with the given session factory and switches the MyBatis log
     * implementation to {@link Jdk14LoggingImpl}.
     *
     * @param sqlSessionFactory SQL session factory, must not be {@code null}
     * @throws NutException if {@code sqlSessionFactory} is {@code null}
     */
    public BaseMyBatisPlugin(SqlSessionFactory sqlSessionFactory) {
        if (sqlSessionFactory == null) {
            throw new NutException("MyBatis SqlSessionFactory Must No Be Null.");
        }
        this.sqlSessionFactory = sqlSessionFactory;
        sqlSessionFactory.getConfiguration().setLogImpl(Jdk14LoggingImpl.class);
    }

    /**
     * Creates the plugin with the given session factory and data source.
     *
     * @param sqlSessionFactory SQL session factory, must not be {@code null}
     * @param dataSource data source; when {@code null} it is looked up from the application
     *                   context in {@link #process}
     * @throws NutException if {@code sqlSessionFactory} is {@code null}
     */
    public BaseMyBatisPlugin(SqlSessionFactory sqlSessionFactory, DataSource dataSource) {
        this(sqlSessionFactory);
        this.dataSource = dataSource;
    }

    /**
     * Creates the plugin without a session factory; one has to be supplied through
     * {@link #setSqlSessionFactory(SqlSessionFactory)} before the plugin is used.
     */
    public BaseMyBatisPlugin() {
    }

    /**
     * Returns the SQL session factory.
     *
     * @return the SQL session factory, or {@code null} if none has been set
     */
    public SqlSessionFactory getSqlSessionFactory() {
        return sqlSessionFactory;
    }

    /**
     * Sets the SQL session factory.
     *
     * @param sqlSessionFactory SQL session factory
     */
    public void setSqlSessionFactory(SqlSessionFactory sqlSessionFactory) {
        this.sqlSessionFactory = sqlSessionFactory;
    }

    /**
     * Returns the MyBatis configuration of the current session factory.
     *
     * @return the configuration, or {@code null} if no session factory is set
     */
    public Configuration getConfiguration() {
        return Optional.ofNullable(getSqlSessionFactory())
                .map(SqlSessionFactory::getConfiguration)
                .orElse(null);
    }

    /**
     * Returns the data source of the environment installed in the MyBatis configuration.
     *
     * <p>Note that this reads from the configuration environment, not from the
     * {@code dataSource} field of this plugin.
     *
     * @return the environment's data source, or {@code null} if the session factory,
     *         configuration, environment or its data source is absent
     */
    public DataSource getDataSource() {
        return Optional.ofNullable(getConfiguration())
                .map(Configuration::getEnvironment)
                .map(Environment::getDataSource)
                .orElse(null);
    }

    /**
     * Collects the beans this plugin registers: the session factory under the name
     * {@code "sqlSessionFactory"} followed by the beans from {@link #getExtendBeanMetadata()}.
     *
     * @return plugin bean metadata collection
     * @throws NutException if no session factory has been set
     */
    @Override
    protected Collection<PluginBeanMetadata> getRegisterPluginBeanMetadata() {
        SqlSessionFactory factory = requireSqlSessionFactory();
        Collection<PluginBeanMetadata> extendBeanMetadata = getExtendBeanMetadata();
        List<PluginBeanMetadata> pluginBeanMetadataList = new ArrayList<>(extendBeanMetadata.size() + 1);
        pluginBeanMetadataList.add(new PluginBeanMetadata("sqlSessionFactory", factory, SqlSessionFactory.class));
        pluginBeanMetadataList.addAll(extendBeanMetadata);
        return pluginBeanMetadataList;
    }

    /**
     * Resolves the data source (from the application context if none was supplied) and the
     * transaction manager, then installs them into the MyBatis configuration.
     *
     * @param applicationContext application context
     * @throws NutException if no {@link DataSource} or {@link TransactionRegistry} bean is
     *                      available, or if no session factory has been set
     */
    @Override
    public void process(ApplicationContext applicationContext) throws Exception {
        if (dataSource == null) {
            try {
                dataSource = applicationContext.getBean(DataSource.class);
            } catch (NoSuchBeanException e) {
                throw new NutException("MyBatisPlugin Is Required A DataSource.", e);
            }
        }
        try {
            TransactionRegistry transactionRegistry = applicationContext.getBean(TransactionRegistry.class);
            setTransactionManagerAndDataSource(transactionRegistry.getTransactionManager(), dataSource);
        } catch (NoSuchBeanException e) {
            throw new NutException("MyBatisPlugin Is Required A TransactionManager.", e);
        }
    }

    /**
     * Sets the data source.
     *
     * @param dataSource data source
     */
    protected void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Replaces the configuration's environment with one that uses a
     * {@code NutTransactionFactory} built from the given transaction manager and the
     * (wrapped) data source. The id of an existing environment is preserved; otherwise
     * {@code "default"} is used.
     *
     * @param transactionManager transaction manager
     * @param dataSource data source
     * @throws NutException if no session factory has been set
     */
    protected void setTransactionManagerAndDataSource(TransactionManager transactionManager, DataSource dataSource) {
        Configuration configuration = requireSqlSessionFactory().getConfiguration();
        Environment environment = configuration.getEnvironment();
        // Keep the id of an already configured environment so lookups by id stay valid.
        String environmentId = environment == null ? "default" : environment.getId();
        configuration.setEnvironment(new Environment.Builder(environmentId)
                .transactionFactory(new NutTransactionFactory(transactionManager))
                .dataSource(warpDataSource(dataSource))
                .build());
    }

    /**
     * Wraps the data source before it is installed into the environment. The default
     * implementation returns the argument unchanged; subclasses may override it.
     *
     * @param dataSource data source
     * @return the wrapped data source
     */
    protected DataSource warpDataSource(DataSource dataSource) {
        return dataSource;
    }

    /**
     * Builds bean metadata for every mapper interface registered in the MyBatis
     * configuration. Each mapper is exposed as a JDK dynamic proxy backed by a
     * {@code MapperProxyInvocationHandler}.
     *
     * @return extended bean metadata collection
     * @throws NutException if no session factory has been set
     */
    protected Collection<PluginBeanMetadata> getExtendBeanMetadata() {
        TransactionRegistry transactionRegistry = getApplicationContext().getBean(TransactionRegistry.class);
        SqlSessionFactory factory = requireSqlSessionFactory();
        Collection<Class<?>> mappers = factory.getConfiguration().getMapperRegistry().getMappers();
        List<PluginBeanMetadata> pluginBeanMetadataList = new ArrayList<>(mappers.size());
        for (Class<?> mapper : mappers) {
            if (!mapper.isInterface()) {
                continue;
            }
            // Define the proxy in the mapper's own class loader: the proxied interface must be
            // visible from the loader passed to Proxy, which is not guaranteed for the plugin's
            // loader when mappers are loaded by a different (e.g. application) class loader.
            Object proxyMapper = Proxy.newProxyInstance(mapper.getClassLoader(), new Class<?>[]{mapper},
                    new MapperProxyInvocationHandler(transactionRegistry.getTransactionManager(), factory, mapper));
            pluginBeanMetadataList.add(new PluginBeanMetadata(ReflectUtils.getDefaultBeanName(mapper), proxyMapper, mapper));
        }
        return pluginBeanMetadataList;
    }

    /**
     * Returns the session factory, failing with a descriptive error instead of a bare
     * {@link NullPointerException} when none has been set.
     *
     * @return the non-null session factory
     * @throws NutException if no session factory has been set
     */
    private SqlSessionFactory requireSqlSessionFactory() {
        SqlSessionFactory factory = getSqlSessionFactory();
        if (factory == null) {
            throw new NutException("MyBatis SqlSessionFactory Has Not Been Set.");
        }
        return factory;
    }
}
